// Copyright (c) 2023-2025 ParadeDB, Inc.
//
// This file is part of ParadeDB - Postgres for Search and Analytics
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

use crate::api::HashMap;
use crate::gucs;
use crate::postgres::customscan::builders::custom_path::{CustomPathBuilder, Flags};
use crate::postgres::customscan::{CreateUpperPathsHookArgs, CustomScan, RelPathlistHookArgs};
use once_cell::sync::Lazy;
use pgrx::{pg_guard, pg_sys, PgMemoryContexts};
use std::collections::hash_map::Entry;

/// Hands `path` to the planner as a candidate path for `rel`.
///
/// `Flags::Force` is inspected and then stripped from `path.flags` before the path is copied
/// into the current memory context, so the copy given to Postgres never carries that bit.
///
/// What happens next depends on the path:
///
/// * **parallel aware** -- the copy is added with `add_partial_path`, `rel`'s existing
///   `pathlist` is discarded, and a second, non-parallel-aware copy with a very high cost is
///   added with `add_path` so that at least one non-partial path exists.
/// * **forced** (and not parallel aware) -- `rel`'s existing `pathlist` is discarded and the
///   path is added with `add_path`.
/// * **otherwise** -- the path is added with `add_path` next to whatever is already there.
///
/// # Safety
///
/// `rel` is dereferenced and passed to Postgres planner functions, so it must point to a valid
/// `RelOptInfo`. This must be called while a memory context is current, because the path is
/// copied into `PgMemoryContexts::CurrentMemoryContext`.
unsafe fn add_path(rel: *mut pg_sys::RelOptInfo, mut path: pg_sys::CustomPath) {
    let force_bit = Flags::Force as u32;
    let forced = path.flags & force_bit != 0;

    // Clear the flag because it's special to us.  This has to be a mask-off (`&= !bit`) and not
    // an XOR: XOR only clears the bit when it was set, and would *turn it on* for every path
    // that was not forced.
    path.flags &= !force_bit;

    let mut custom_path = PgMemoryContexts::CurrentMemoryContext
        .copy_ptr_into(&mut path, std::mem::size_of_val(&path));

    if (*custom_path).path.parallel_aware {
        // add the partial path since the user-generated plan is parallel aware
        pg_sys::add_partial_path(rel, custom_path.cast());

        // remove all the existing possible paths
        (*rel).pathlist = std::ptr::null_mut();

        // then make another copy of it, increase its costs really, really high and
        // submit it as a regular path too, immediately after clearing out all the other
        // existing possible paths.
        //
        // We don't want postgres to choose this path, but we have to have at least one
        // non-partial path available for it to consider
        let copy = PgMemoryContexts::CurrentMemoryContext
            .copy_ptr_into(&mut path, std::mem::size_of_val(&path));
        (*copy).path.parallel_aware = false;
        (*copy).path.total_cost = 1000000000.0;
        (*copy).path.startup_cost = 1000000000.0;

        // will be added down below
        custom_path = copy.cast();
    } else if forced {
        // remove all the existing possible paths
        (*rel).pathlist = std::ptr::null_mut();
    }

    // add this path for consideration
    pg_sys::add_path(rel, custom_path.cast());
}

/// Installs the custom scan `CS` behind Postgres' `set_rel_pathlist_hook`.
///
/// The hook that was in place before this call is stored under `CS`'s `TypeId`. The replacement
/// hook first calls that stored hook (if there was one) and then runs
/// `paradedb_rel_pathlist_callback::<CS>`. Finally `CS::custom_scan_methods()` is passed to
/// `RegisterCustomScanMethods`.
///
/// # Panics
///
/// Panics if `CS` has already been registered through this function.
pub fn register_rel_pathlist<CS>(_: CS)
where
    CS: CustomScan<Args = RelPathlistHookArgs> + 'static,
{
    // Hooks that were displaced by a registration, keyed by the registering scan type.
    static mut PREV_HOOKS: Lazy<HashMap<std::any::TypeId, pg_sys::set_rel_pathlist_hook_type>> =
        Lazy::new(Default::default);

    #[pg_guard]
    extern "C-unwind" fn __priv_callback<CS>(
        root: *mut pg_sys::PlannerInfo,
        rel: *mut pg_sys::RelOptInfo,
        rti: pg_sys::Index,
        rte: *mut pg_sys::RangeTblEntry,
    ) where
        CS: CustomScan<Args = RelPathlistHookArgs> + 'static,
    {
        let key = std::any::TypeId::of::<CS>();

        unsafe {
            // run whatever hook we replaced before doing our own work
            #[allow(static_mut_refs)]
            let displaced = PREV_HOOKS.get(&key).copied().flatten();
            if let Some(hook) = displaced {
                hook(root, rel, rti, rte);
            }

            paradedb_rel_pathlist_callback::<CS>(root, rel, rti, rte);
        }
    }

    let key = std::any::TypeId::of::<CS>();

    unsafe {
        #[allow(static_mut_refs)]
        let Entry::Vacant(slot) = PREV_HOOKS.entry(key) else {
            panic!("{} is already registered", std::any::type_name::<CS>())
        };
        // remember the current hook before it gets overwritten below
        slot.insert(pg_sys::set_rel_pathlist_hook);

        pg_sys::set_rel_pathlist_hook = Some(__priv_callback::<CS>);

        pg_sys::RegisterCustomScanMethods(CS::custom_scan_methods())
    }
}

/// `set_rel_pathlist_hook` body for the custom scan `CS`.
///
/// Returns without doing anything when `gucs::enable_custom_scan()` is false. Otherwise `CS` is
/// asked to create a custom path from a `CustomPathBuilder` seeded with the hook arguments, and
/// if it produces one, that path is submitted for `rel` through this module's `add_path`.
#[pg_guard]
pub extern "C-unwind" fn paradedb_rel_pathlist_callback<CS>(
    root: *mut pg_sys::PlannerInfo,
    rel: *mut pg_sys::RelOptInfo,
    rti: pg_sys::Index,
    rte: *mut pg_sys::RangeTblEntry,
) where
    CS: CustomScan<Args = RelPathlistHookArgs> + 'static,
{
    unsafe {
        if gucs::enable_custom_scan() {
            let hook_args = RelPathlistHookArgs {
                root,
                rel,
                rti,
                rte,
            };
            let builder = CustomPathBuilder::new(root, rel, hook_args);

            // `None` means the scan has nothing to offer for this relation
            if let Some(path) = CS::create_custom_path(builder) {
                add_path(rel, path);
            }
        }
    }
}

/// Installs the custom scan `CS` behind Postgres' `create_upper_paths_hook`.
///
/// The hook that was in place before this call is stored under `CS`'s `TypeId`. The replacement
/// hook first calls that stored hook (if there was one) and then runs
/// `paradedb_upper_paths_callback::<CS>`. Finally `CS::custom_scan_methods()` is passed to
/// `RegisterCustomScanMethods`.
///
/// # Panics
///
/// Panics if `CS` has already been registered through this function.
pub fn register_upper_path<CS>(_: CS)
where
    CS: CustomScan<Args = CreateUpperPathsHookArgs> + 'static,
{
    // Hooks that were displaced by a registration, keyed by the registering scan type.
    static mut PREV_HOOKS: Lazy<
        HashMap<std::any::TypeId, pg_sys::create_upper_paths_hook_type>,
    > = Lazy::new(Default::default);

    #[pg_guard]
    extern "C-unwind" fn __priv_callback<CS>(
        root: *mut pg_sys::PlannerInfo,
        stage: pg_sys::UpperRelationKind::Type,
        input_rel: *mut pg_sys::RelOptInfo,
        output_rel: *mut pg_sys::RelOptInfo,
        extra: *mut ::std::os::raw::c_void,
    ) where
        CS: CustomScan<Args = CreateUpperPathsHookArgs> + 'static,
    {
        let key = std::any::TypeId::of::<CS>();

        unsafe {
            // run whatever hook we replaced before doing our own work
            #[allow(static_mut_refs)]
            let displaced = PREV_HOOKS.get(&key).copied().flatten();
            if let Some(hook) = displaced {
                hook(root, stage, input_rel, output_rel, extra);
            }

            paradedb_upper_paths_callback::<CS>(root, stage, input_rel, output_rel, extra);
        }
    }

    let key = std::any::TypeId::of::<CS>();

    unsafe {
        #[allow(static_mut_refs)]
        let Entry::Vacant(slot) = PREV_HOOKS.entry(key) else {
            panic!("{} is already registered", std::any::type_name::<CS>())
        };
        // remember the current hook before it gets overwritten below
        slot.insert(pg_sys::create_upper_paths_hook);

        pg_sys::create_upper_paths_hook = Some(__priv_callback::<CS>);

        pg_sys::RegisterCustomScanMethods(CS::custom_scan_methods())
    }
}

/// `create_upper_paths_hook` body for the custom scan `CS`.
///
/// Only acts when `stage` is `UPPERREL_GROUP_AGG` and `gucs::enable_aggregate_custom_scan()` is
/// true; in every other case it returns immediately. When it does act, `CS` is asked to create
/// a custom path from a `CustomPathBuilder` seeded with the hook arguments, and if it produces
/// one, that path is submitted for `output_rel` through this module's `add_path`.
#[pg_guard]
pub extern "C-unwind" fn paradedb_upper_paths_callback<CS>(
    root: *mut pg_sys::PlannerInfo,
    stage: pg_sys::UpperRelationKind::Type,
    input_rel: *mut pg_sys::RelOptInfo,
    output_rel: *mut pg_sys::RelOptInfo,
    extra: *mut ::std::os::raw::c_void,
) where
    CS: CustomScan<Args = CreateUpperPathsHookArgs> + 'static,
{
    // the GUC is only consulted once the stage is known to be the one we care about
    let applicable = stage == pg_sys::UpperRelationKind::UPPERREL_GROUP_AGG
        && gucs::enable_aggregate_custom_scan();
    if !applicable {
        return;
    }

    unsafe {
        let hook_args = CreateUpperPathsHookArgs {
            root,
            stage,
            input_rel,
            output_rel,
            extra,
        };
        let builder = CustomPathBuilder::new(root, output_rel, hook_args);

        // `None` means the scan has nothing to offer for this upper relation
        if let Some(path) = CS::create_custom_path(builder) {
            add_path(output_rel, path);
        }
    }
}
